MNIST Geodesics - 3 digits

In this notebook, we use the first three MNIST digit classes (0, 1 and 2) to compare the different approaches we have for approximating the Riemannian distance:

  1. Riemannian length of the Euclidean interpolation
  2. Discrete Geodesic Algorithm from Shao et al. (2017)
  3. ODE from the Latent Space Oddity paper
  4. Shortest path in the latent graph
  5. Using the shortest path as an initialization for the discrete geodesic algorithm

Conclusion

Stochastic Riemannian length of the shortest curves found:

  • Method 1 (Euclidean Interpolation): 23.7
  • Method 2 (Discrete Geodesic): 21.8
  • Method 3 (ODE): 21.8
  • Method 4 (Graph): 23.0
  • Method 5 (Graph as initialization for discrete): 21.8

Time per 100 geodesic approximations:

  • Method 1 (Euclidean Interpolation): 3s
  • Method 2 (Discrete Geodesic): 4min 47s
  • Method 3 (ODE): 100 × 16min 33s (extrapolated from a single geodesic)
  • Method 4 (Graph): 4s
  • Method 5 (Graph as initialization for discrete): 4min 37s

Imports and setup of plotting library

In [1]:
%load_ext autoreload
%autoreload 2
import numpy as np
import tensorflow as tf
from copy import deepcopy
import plotly.graph_objs as go
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot

# Set up plotly
init_notebook_mode(connected=True)
layout = go.Layout(
    width=700,
    height=500,
    margin=go.Margin(l=60, r=60, b=40, t=20),
    showlegend=False
)
config={'showLink': False}

# Make results completely reproducible. Use the seed that the Latent Space 
# Oddity experiment in 10-mnist-digits-012.ipynb used for comparable results
seed = 9
np.random.seed(seed)
tf.set_random_seed(seed)

digit_classes = [0,1,2]
/Users/kilian/dev/tum/2018-mlic-kilian/venv/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters

Create the VAE

following the implementation details in Appendix D of the Latent Space Oddity paper.

In [2]:
from tensorflow.python.keras import Sequential, Model
from tensorflow.python.keras.layers import Dense, Input, Lambda
from src.vae import VAE
from src.rbf import RBFLayer

# Implementation details from Appendix D
input_dim = 784
latent_dim = 2
l2_reg = tf.keras.regularizers.l2(1e-5)

# Create the encoder models
enc_input = Input((input_dim,))
enc_shared = Dense(64, activation='tanh', kernel_regularizer=l2_reg)
enc_mean = Sequential([
    enc_shared,
    Dense(32, activation='tanh', kernel_regularizer=l2_reg),
    Dense(latent_dim, activation='linear', kernel_regularizer=l2_reg)
])
enc_var = Sequential([
    enc_shared,
    Dense(32, activation='tanh', kernel_regularizer=l2_reg),
    Dense(latent_dim, activation='softplus', kernel_regularizer=l2_reg)
])
enc_mean = Model(enc_input, enc_mean(enc_input))
enc_var = Model(enc_input, enc_var(enc_input))

# Create the decoder models
dec_input = Input((latent_dim,))
dec_mean = Sequential([
    Dense(32, activation='tanh', kernel_regularizer=l2_reg),
    Dense(64, activation='tanh', kernel_regularizer=l2_reg),
    Dense(input_dim, activation='sigmoid', kernel_regularizer=l2_reg)
])
dec_mean = Model(dec_input, dec_mean(dec_input))

# Build the RBF network
num_centers = 64
a = 2.0
rbf = RBFLayer([input_dim], num_centers)
dec_var = Model(dec_input, rbf(dec_input))

vae = VAE(enc_mean, enc_var, dec_mean, dec_var, dec_stddev=1.)
WARNING:tensorflow:Output "model_6" missing from loss dictionary. We assume this was done on purpose, and we will not be expecting any data to be passed to "model_6" during training.

Filter the digits from MNIST

In [3]:
from tensorflow.python.keras.datasets import mnist

# Train the VAE on MNIST digits
(x_train_all, y_train_all), _ = mnist.load_data()

# Filter the digit classes from the mnist data
x_train = []
y_train = []
for digit_class in digit_classes:
    count = 0
    for x, y in zip(x_train_all, y_train_all):
        if y == digit_class:
            x_train.append(x)
            y_train.append(y)
            count += 1
            if count == 1000:
                break
                
x_train = np.array(x_train).astype('float32') / 255.
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
y_train = np.array(y_train)

# Shuffle the data
p = np.random.permutation(len(x_train))
x_train = x_train[p]
y_train = y_train[p]

Train the VAE

without training the generator's variance network. This will be trained separately later.

In [4]:
history = vae.model.fit(x_train,
              epochs=300,
              batch_size=32,
              validation_split=0.1,
              verbose=0)

# Plot the losses
data = [go.Scatter(y=history.history['loss'], name='Train Loss'),
        go.Scatter(y=history.history['val_loss'], name='Validation Loss')]
plot_layout = deepcopy(layout)
plot_layout['xaxis'] = {'title': 'Epoch'}
plot_layout['yaxis'] = {'title': 'NELBO'}
plot_layout['showlegend'] = True
iplot(go.Figure(data=data, layout=plot_layout), config=config)

Visualize the latent space

In [5]:
# Display a 2D plot of the classes in the latent space
encoded_sampled, encoded_mean, encoded_var = vae.encoder.predict(x_train)

# Plot
scatter_data = []
colors = ['#F7AB3D', '#825446', '#2D4366']
for digit_class in digit_classes:
    y_class = [y == digit_class for y in y_train]
    indices = np.arange(len(y_class))[y_class]
    x_class = encoded_mean[y_class]
    scatter_data.append(go.Scatter(
        x = x_class[:, 0],
        y = x_class[:, 1],
        mode = 'markers',
        marker = {'color': colors[len(scatter_data)]},
        name = digit_class,
        hoverinfo = 'text',
        text = indices
    ))
iplot(go.Figure(data=scatter_data, layout=layout), config=config)

Train the generator's variance network

For this, we first have to find the RBF centers by clustering the latent points.

In [6]:
from sklearn.cluster import KMeans

# Find the centers of the latent representations
kmeans_model = KMeans(n_clusters=num_centers, random_state=0)
kmeans_model = kmeans_model.fit(encoded_mean)
centers = kmeans_model.cluster_centers_

# Visualize the centers
center_plot = go.Scatter(
    x = centers[:, 0],
    y = centers[:, 1],
    mode = 'markers',
    marker = {'color': 'red'}
)
data = scatter_data + [center_plot] 
iplot(go.Figure(data=data, layout=layout), config=config)

Compute the bandwidths

In [7]:
# Cluster the latent representations
clustering = dict((c_i, []) for c_i in range(num_centers))
for z_i, c_i in zip(encoded_mean, kmeans_model.predict(encoded_mean)):
    clustering[c_i].append(z_i)
    
bandwidths = []
for c_i, cluster in clustering.items():
    if cluster:
        diffs = np.array(cluster) - centers[c_i]
        avg_dist = np.mean(np.linalg.norm(diffs, axis=1))
        bandwidth = 0.5 / (a * avg_dist)**2
    else:
        bandwidth = 0
    bandwidths.append(bandwidth)
bandwidths = np.array(bandwidths)

Train the variance network

In [8]:
# Train the RBF
vae.recompile_for_var_training()
rbf_kernel = rbf.get_weights()[0]
rbf.set_weights([rbf_kernel, centers, bandwidths])

history = vae.model.fit(x_train,
                        epochs=300,
                        batch_size=32,
                        validation_split=0.1,
                        verbose=0)

# Plot the losses
data = [go.Scatter(y=history.history['loss'],
                   name='Train Loss'),
        go.Scatter(y=history.history['val_loss'],
                   name='Validation Loss')]
plot_layout = deepcopy(layout)
plot_layout['xaxis'] = {'title': 'Epoch'}
plot_layout['yaxis'] = {'title': 'NELBO'}
plot_layout['showlegend'] = True
iplot(go.Figure(data=data, layout=plot_layout), config=config)
WARNING:tensorflow:Output "model_6" missing from loss dictionary. We assume this was done on purpose, and we will not be expecting any data to be passed to "model_6" during training.
In [9]:
from src.util import wrap_model_in_float64

# Get the mean and std predictors
_, mean_output, var_output = vae.decoder.output
sqrt_layer = Lambda(tf.sqrt)
dec_mean = Model(vae.decoder.input, mean_output)
dec_std = Model(vae.decoder.input, sqrt_layer(var_output))
dec_mean = wrap_model_in_float64(dec_mean)
dec_std = wrap_model_in_float64(dec_std)

session = tf.keras.backend.get_session()

Choose two latent points

for finding a geodesic.

In [10]:
z_start, z_end = encoded_mean[[2102, 1091]]

# Visualize the start and end point
task_plot = go.Scatter(
    x = [z_start[0], z_end[0]],
    y = [z_start[1], z_end[1]],
    mode = 'markers',
    marker = {'color': '#d32f2f'}
)
data = scatter_data + [task_plot]
iplot(go.Figure(data=data, layout=layout), config=config)

Plot the magnification factors

In [11]:
from src.plot import plot_magnification_factor

heatmap_z1 = np.linspace(-4, 4, 100)
heatmap_z2 = np.linspace(-4, 4, 100)
heatmap = plot_magnification_factor(session, 
                                    heatmap_z1,
                                    heatmap_z2, 
                                    dec_mean, 
                                    dec_std, 
                                    additional_data=scatter_data + [task_plot],
                                    layout=layout,
                                    log_scale=True)
Computing Magnification Factors: 100%|██████████| 500/500 [00:01<00:00, 360.54it/s]

Define the evaluation metric

Before we start comparing the geodesic approximations, we need to define the evaluation metric. For each curve, we take equidistant steps in the latent space and compute the Riemannian length by numerical integration. We also plot the curve's velocity profile.
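The evaluate_curve function in the next cell uses the TensorFlow length ops from src.util; the same numerical integration can be sketched in plain numpy (curve_length and metric_fn are illustrative names, not part of src):

```python
import numpy as np

def curve_length(curve, metric_fn):
    """Approximate the Riemannian length of a polyline by summing
    sqrt(v^T G(m) v) over segments, with the metric G evaluated at
    each segment midpoint."""
    length = 0.0
    for a, b in zip(curve[:-1], curve[1:]):
        v = b - a                      # segment displacement
        G = metric_fn(0.5 * (a + b))   # metric at the segment midpoint
        length += np.sqrt(v @ G @ v)
    return length

# Sanity check: under the identity (Euclidean) metric, a discretized
# quarter circle of radius 1 has length close to pi/2.
t = np.linspace(0.0, np.pi / 2, 200)
quarter_circle = np.stack([np.cos(t), np.sin(t)], axis=1)
print(curve_length(quarter_circle, lambda z: np.eye(2)))  # ~1.5708
```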

In [12]:
from src.util import get_length_op, get_lengths_op, interpolate

curve_ph = tf.placeholder(tf.float64, [None, 2])
length_op, _ = get_length_op(curve_ph, dec_mean, dec_std)
lengths_op = get_lengths_op(curve_ph, dec_mean, dec_std)
lengths_op = tf.squeeze(lengths_op)

def evaluate_curve(curve, num_nodes=200, with_velocity_plot=True, 
                   verbose=True):
    curve = interpolate(curve, num_nodes)
    lengths = session.run(lengths_op, feed_dict={curve_ph: curve})
    length = np.sum(lengths)
    if verbose:
        print('Curve length: ', length)
    
    if with_velocity_plot:
        plot_velocity(lengths)
        
    return length
    
def plot_velocity(lengths):
    num_nodes = len(lengths)
    velocities = lengths * (num_nodes - 1)
    trace = go.Scatter(
        x = np.linspace(0, 1, num_nodes),
        y = velocities
    )
    iplot(go.Figure(data=[trace], layout=go.Layout(
        width=700,
        height=100,
        margin=go.Margin(l=60, r=60, b=20, t=20),
        showlegend=False
    )), config=config)

Method 1 - Euclidean Interpolation

In [13]:
t_nodes = np.linspace(0, 1, 50)
euclidean_curve = z_start + np.outer(t_nodes, z_end - z_start)
In [14]:
evaluate_curve(euclidean_curve)
Curve length:  23.735611661291685
Out[14]:
23.735611661291685

Method 2 - Discrete Geodesics
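The find_geodesic_discrete function below comes from src.discrete and follows Shao et al. (2017). To illustrate the underlying idea, here is a toy sketch that minimizes the discrete curve energy by gradient descent under a flat (Euclidean) metric, where the minimizer is the straight line; all names are illustrative, and none of this is the src implementation:

```python
import numpy as np

def discrete_geodesic_flat(z_start, z_end, num_nodes=20, steps=2000, lr=0.1):
    """Minimize the discrete curve energy sum_i ||x_{i+1} - x_i||^2 by
    gradient descent on the interior nodes, endpoints held fixed."""
    curve = np.linspace(z_start, z_end, num_nodes)
    curve += 0.5 * np.random.randn(*curve.shape)  # perturb the initialization
    curve[0], curve[-1] = z_start, z_end          # pin the endpoints
    for _ in range(steps):
        # dE/dx_i = 2 * (2 x_i - x_{i-1} - x_{i+1}) for interior nodes
        grad = 2.0 * (2.0 * curve[1:-1] - curve[:-2] - curve[2:])
        curve[1:-1] -= lr * grad
    return curve

# Under a flat metric the optimum is the straight line between the endpoints.
np.random.seed(0)
curve = discrete_geodesic_flat(np.zeros(2), np.ones(2))
```

With a non-flat metric (as in the VAE's latent space), the energy gradient additionally involves the derivative of the metric, which is what the src implementation handles via TensorFlow autodiff.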

In [15]:
%%time
from src.discrete import find_geodesic_discrete

discrete_curve, discrete_iterations = find_geodesic_discrete(
    session, 
    z_start, z_end, 
    dec_mean, 
    std_generator=dec_std,
    num_nodes=50,
    max_steps=400,
    learning_rate=0.01,
    log_every=50,
    save_every=30)

print('-' * 20)
Step 0, Length 23.731371, Energy 331.648435, Max velocity ratio 7.501815
Step 50, Length 22.648713, Energy 274.766113, Max velocity ratio 2.176842
Step 100, Length 22.300530, Energy 257.620281, Max velocity ratio 1.654214
Step 150, Length 21.858387, Energy 243.167057, Max velocity ratio 1.435469
Step 200, Length 21.784006, Energy 238.576237, Max velocity ratio 1.249839
Step 250, Length 21.780728, Energy 237.500067, Max velocity ratio 1.144518
Step 300, Length 21.778819, Energy 237.249160, Max velocity ratio 1.075147
Step 350, Length 21.778341, Energy 237.183367, Max velocity ratio 1.056452
Step 400, Length 21.778118, Energy 237.173369, Max velocity ratio 1.048920
--------------------
CPU times: user 12.8 s, sys: 4.49 s, total: 17.3 s
Wall time: 6.35 s
In [16]:
from src.plot import plot_latent_curve_iterations

plot_latent_curve_iterations(discrete_iterations, [heatmap] + scatter_data, 
                             layout, step_size=30)

Show the effect of the curve normalizer

In [17]:
curve_test = interpolate(discrete_curve, 20)
plot_latent_curve_iterations([curve_test], [heatmap] + scatter_data, layout)
In [18]:
evaluate_curve(discrete_curve)
Curve length:  21.81436024514179
Out[18]:
21.81436024514179

Method 3 - ODE

In [19]:
%%time
from src.geodesic import find_geodesic

ode_result, ode_iterations = find_geodesic(session, z_start, z_end, 
                                           dec_mean, std_generator=dec_std, 
                                           initial_nodes=20, max_nodes=1000,
                                           use_fun_jac=True)
print('-' * 20)
   Iteration    Max residual    Total nodes    Nodes added  
       1          1.70e+00          20             36       
       2          4.31e+00          56             91       
       3          5.50e+00          147            228      
       4          9.19e-04          375             0       
Solved in 4 iterations, number of nodes 375, maximum relative residual 9.19e-04.
--------------------
CPU times: user 26min 38s, sys: 12min 33s, total: 39min 12s
Wall time: 16min 33s
In [20]:
plot_latent_curve_iterations(ode_iterations[::10], [heatmap] + scatter_data, 
                             layout, step_size=10)
In [21]:
ode_curve = ode_result.sol(ode_result.x)[0:2].T
evaluate_curve(ode_curve)
Curve length:  21.81323380600473
Out[21]:
21.81323380600473

Method 4 - Graph

We use the 1,000 encoded points per digit class and add two noisy copies of them (unit-variance Gaussian noise). This gives a total of 9,000 points, which we will use for our graph in the latent space.

In [22]:
graph_points = encoded_mean
extensions = [graph_points + np.random.randn(*graph_points.shape) 
              for _ in range(2)]
graph_points = np.concatenate([graph_points] + extensions)
print(graph_points.shape)
(9000, 2)

Compute the Riemannian distances of neighboring points

To get the nearest neighbors of each point, we use the get_neighbors function from src.graph. It is explained and defined in 13-graph-geodesics.ipynb.
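The get_neighbors function itself lives in src.graph; a minimal brute-force stand-in could look like this (get_neighbors_sketch is a hypothetical name, not the actual implementation):

```python
import numpy as np

def get_neighbors_sketch(i_point, points, k):
    """Return the indices of the k nearest Euclidean neighbors of
    points[i_point], excluding the query point itself."""
    dists = np.linalg.norm(points - points[i_point], axis=1)
    dists[i_point] = np.inf            # never return the query point
    return np.argsort(dists)[:k]
```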

Given the get_neighbors function, we compute the Riemannian distance between each point and each of its neighbors. We approximate the integral with a single midpoint evaluation: $\int_0^1 \left\| J_{\gamma_t} \dot{\gamma}_t \right\| \mathrm{d}t \approx \left\| J_{\gamma_{1/2}} \dot{\gamma}_{1/2} \right\|$
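The loop in the next cell computes these edge lengths with metric_op from src.util; as a standalone numpy sketch of the single-midpoint approximation (edge_length_midpoint and metric_fn are illustrative names):

```python
import numpy as np

def edge_length_midpoint(point, neighbor, metric_fn):
    """One-point (midpoint) approximation of the Riemannian distance
    along the straight segment from point to neighbor."""
    v = neighbor - point                 # constant velocity of the segment
    G = metric_fn(point + 0.5 * v)       # metric evaluated at the midpoint
    return np.sqrt(v @ G @ v)
```

Under the identity metric this reduces to the Euclidean distance; under a scaled metric c·I the length scales by sqrt(c).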

In [23]:
import networkx as nx
from tqdm import tqdm

from src.util import get_metric_op
from src.graph import get_neighbors

point_ph = tf.placeholder(tf.float64, [2])
metric_op = get_metric_op(point_ph, dec_mean, dec_std)

# Compute the distance between the kNNs in Euclidean space
k = 4
graph = nx.Graph()
for i_point, point in enumerate(graph_points):
    graph.add_node(i_point, pos=point)
    
for i_point, point in enumerate(tqdm(graph_points)):
    neighbor_indices = get_neighbors(i_point, graph_points, k)
    
    for i_neighbor in neighbor_indices:
        if graph.has_edge(i_neighbor, i_point): 
            continue
        
        neighbor = graph_points[i_neighbor]
        middle = point + 0.5 * (neighbor - point) 
        velocity = neighbor - point
        metric = session.run(metric_op, feed_dict={point_ph: middle})
        length = velocity.T.dot(metric).dot(velocity)
        length = np.sqrt(length)
        graph.add_edge(i_point, i_neighbor, weight=length) 
100%|██████████| 9000/9000 [00:47<00:00, 188.82it/s]

Visualize a subgraph

and the relative weight of its edges (Riemannian length divided by Euclidean length). Green means a low relative weight, red a high relative weight.

In [24]:
from src.plot import plot_graph_with_edge_colors

x_range = [-1., 1.]
y_range = [-1., 1.]

subnodes = []
for node in graph.nodes():
    pos = graph.node[node]['pos']
    if (x_range[0] <= pos[0] <= x_range[1] and
        y_range[0] <= pos[1] <= y_range[1]):
        subnodes.append(node)

subgraph = graph.subgraph(subnodes)
graph_plot = plot_graph_with_edge_colors(subgraph, layout=layout)

Compute the shortest path

between the two points from above.

In [25]:
z_start_index = np.where((graph_points == z_start).all(axis=1))[0][0]
z_end_index = np.where((graph_points == z_end).all(axis=1))[0][0]
In [26]:
%%time
from networkx.algorithms.shortest_paths.generic import shortest_path
path = shortest_path(graph, z_start_index, z_end_index, weight='weight')
length = 0
for source, sink in zip(path[:-1], path[1:]):
    length += graph[source][sink]['weight']
print('Path length:', length)
print('-' * 20)
Path length: 23.10135662424389
--------------------
CPU times: user 79.2 ms, sys: 2.21 ms, total: 81.4 ms
Wall time: 80.4 ms

Visualize the shortest path

In [27]:
from src.plot import plot_graph

# Construct a subgraph from the path
path_graph = nx.Graph()
for point in path:
    path_graph.add_node(point, pos=graph_points[point])
for source, sink in zip(path[:-1], path[1:]):
    weight = graph[source][sink]['weight']
    path_graph.add_edge(source, sink, weight=weight) 

_ = plot_graph(path_graph, layout=layout, edge_color='#00DD00', 
               node_color='#00DD00', additional_data=[heatmap] + scatter_data)

Measure the actual curve length

Since we computed the Riemannian distance of each edge with only a single midpoint evaluation, the graph length is not exact. It is less biased than the discrete geodesic algorithm's length estimate, but for a fair comparison we re-measure it with the interpolate function as well.
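The interpolate function comes from src.util; a plausible sketch of such an arc-length resampling could look like this (interpolate_sketch is a hypothetical name, assuming equidistant resampling by cumulative arc length):

```python
import numpy as np

def interpolate_sketch(curve, num_nodes):
    """Resample a polyline to num_nodes points that are equidistant
    in (Euclidean) arc length along the original polyline."""
    seg = np.linalg.norm(np.diff(curve, axis=0), axis=1)
    s = np.concatenate([[0.0], np.cumsum(seg)])      # cumulative arc length
    s_new = np.linspace(0.0, s[-1], num_nodes)
    return np.stack([np.interp(s_new, s, curve[:, d])
                     for d in range(curve.shape[1])], axis=1)
```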

In [28]:
graph_curve = graph_points[path]
evaluate_curve(graph_curve)
Curve length:  23.030336804283685
Out[28]:
23.030336804283685

Method 5 - Graph as init to discrete algorithm

In [29]:
graphref_curve, _ = find_geodesic_discrete(
    session, 
    z_start, z_end, 
    dec_mean, 
    std_generator=dec_std,
    num_nodes=50,
    max_steps=400,
    learning_rate=0.01,
    log_every=50,
    curve_init=graph_curve)
Step 0, Length 22.828364, Energy 340.070491, Max velocity ratio 19.046716
Step 50, Length 21.813992, Energy 260.445630, Max velocity ratio 2.448730
Step 100, Length 21.816412, Energy 248.705694, Max velocity ratio 1.787939
Step 150, Length 21.793488, Energy 241.658040, Max velocity ratio 1.458299
Step 200, Length 21.784169, Energy 238.486223, Max velocity ratio 1.240579
Step 250, Length 21.780693, Energy 237.502750, Max velocity ratio 1.143958
Step 300, Length 21.778794, Energy 237.228808, Max velocity ratio 1.070841
Step 350, Length 21.778801, Energy 237.219511, Max velocity ratio 1.066448
Step 400, Length 21.778306, Energy 237.178999, Max velocity ratio 1.054801
In [30]:
evaluate_curve(graphref_curve)
Curve length:  21.81447085836175
Out[30]:
21.81447085836175

Visualize the refined graph solution

Together with the shortest path (blue) and the standard discrete geodesic solution (red). The refined graph solution (green) is hidden because it lies practically on top of the red curve.

In [31]:
# Plot the graph curve
graph_curve_plot = go.Scatter(
    x=graph_curve[:, 0],
    y=graph_curve[:, 1],
    mode='lines',
    line={'width': 5, 'color': '#3CA8FF'}
)
# Plot the refined graph curve
graphref_curve_plot = go.Scatter(
    x=graphref_curve[:, 0],
    y=graphref_curve[:, 1],
    mode='lines',
    line={'width': 5, 'color': 'green'}
)
# Plot the discrete curve
discrete_curve_plot = go.Scatter(
    x=discrete_curve[:, 0],
    y=discrete_curve[:, 1],
    mode='lines',
    line={'width': 5, 'color': '#d32f2f'}
)
data = [heatmap] + scatter_data + [graph_curve_plot, graphref_curve_plot,
        discrete_curve_plot, task_plot]
iplot(go.Figure(data=data, layout=layout), config=config)

Multiple Points Benchmark

Measure the runtime of each approach on 100 random pairs of points. We don't use the ODE here, since it takes orders of magnitude longer than the discrete geodesic algorithm without giving better geodesic approximations.

In [32]:
z_starts = np.random.choice(len(encoded_mean), 100)
z_ends = np.random.choice(len(encoded_mean), 100)

Method 1 - Euclidean Interpolation

In [33]:
def test_euclidean(z_starts, z_ends):
    curve_ph = tf.placeholder(tf.float64, [None, 2])
    length_op = get_length_op(curve_ph, dec_mean, dec_std)
    curves = []
    lengths = []

    for z_start, z_end in zip(encoded_mean[z_starts], encoded_mean[z_ends]):
        t_nodes = np.linspace(0, 1, 20)
        curve = z_start + np.outer(t_nodes, z_end - z_start)
        length, _ = session.run(length_op, feed_dict={curve_ph: curve})
        curves.append(curve)
        lengths.append(length)
    return curves, lengths

Method 2 - Discrete Geodesics

In [34]:
from src.discrete import find_geodesics_discrete

def test_discrete(z_starts, z_ends):
    return find_geodesics_discrete(
        session, 
        encoded_mean[z_starts], encoded_mean[z_ends], 
        dec_mean, 
        std_generator=dec_std,
        num_nodes=50,
        max_steps=200,
        learning_rate=0.01,
        return_iterations=True)

Method 4 - Graph

In [35]:
def test_graph(z_starts, z_ends):
    curves = []
    lengths = []
    for z_start, z_end in zip(z_starts, z_ends):
        path = shortest_path(graph, z_start, z_end, weight='weight')
        curve = graph_points[path]
        length = 0
        for source, sink in zip(path[:-1], path[1:]):
            length += graph[source][sink]['weight']
        curves.append(curve)
        lengths.append(length)
    return curves, lengths

Method 5 - Graph with refinement

In [36]:
def test_graph_refinement(z_starts, z_ends):
    curve_inits = []
    for z_start, z_end in zip(z_starts, z_ends):
        path = shortest_path(graph, z_start, z_end, weight='weight')
        curve = graph_points[path]
        curve_inits.append(curve)
    return find_geodesics_discrete(
        session, 
        encoded_mean[z_starts], encoded_mean[z_ends], 
        dec_mean, 
        std_generator=dec_std,
        num_nodes=50,
        max_steps=200,
        learning_rate=0.01,
        curve_inits=curve_inits,
        return_iterations=True)

Measure the time per geodesic

In [37]:
%%time
eucl_curves, eucl_est_lengths = test_euclidean(z_starts, z_ends)
print('-' * 20)
--------------------
CPU times: user 3.65 s, sys: 173 ms, total: 3.82 s
Wall time: 3.3 s
In [38]:
%%time
discrete_curves, discrete_est_lengths, discrete_iterations = test_discrete(z_starts, z_ends)
print('-' * 20)
--------------------
CPU times: user 9min 48s, sys: 3min 42s, total: 13min 30s
Wall time: 4min 47s
In [39]:
%%time
graph_curves, graph_est_lengths = test_graph(z_starts, z_ends)
print('-' * 20)
--------------------
CPU times: user 4.72 s, sys: 12.2 ms, total: 4.73 s
Wall time: 4.73 s
In [40]:
%%time
graphref_curves, graphref_est_lengths, graphref_iterations = test_graph_refinement(z_starts, z_ends)
print('-' * 20)
--------------------
CPU times: user 9min 13s, sys: 3min 31s, total: 12min 44s
Wall time: 4min 37s

Measure the actual lengths

In [41]:
def evaluate_curves(curves, estimated_lengths, num_nodes=200):
    lengths = []
    estimation_errors = []
    for curve, estimated_length in zip(curves, estimated_lengths):
        curve = interpolate(curve, num_nodes)
        length = session.run(length_op, feed_dict={curve_ph: curve})
        lengths.append(length)
        estimation_errors.append(estimated_length - length)
        
    print('Estimation error mean: ', np.mean(estimation_errors))
    print('Estimation error std: ', np.std(estimation_errors))
    return lengths
In [42]:
eucl_lengths = evaluate_curves(eucl_curves, eucl_est_lengths)
Estimation error mean:  -0.025323633525505533
Estimation error std:  0.18010913906151618
In [43]:
discrete_lengths = evaluate_curves(discrete_curves, discrete_est_lengths)
Estimation error mean:  -0.17321350697869367
Estimation error std:  1.0584204310284133
In [44]:
graph_lengths = evaluate_curves(graph_curves, graph_est_lengths)
Estimation error mean:  0.11040795248670482
Estimation error std:  0.13813771883229867
In [45]:
graphref_lengths = evaluate_curves(graphref_curves, graphref_est_lengths)
Estimation error mean:  -0.15509482644755765
Estimation error std:  0.9294213459348353

Plot the lengths

In [46]:
eucl_trace = go.Scatter(
    x = graph_lengths,
    y = np.array(eucl_lengths),
    mode = 'markers',
    marker = {'size': 8, 'symbol': 'x', 'color': 'orange'},
    name = 'Euclidean'
)
graph_trace = go.Scatter(
    x = graph_lengths,
    y = np.array(graph_lengths),
    mode = 'markers',
    marker = {'size': 8, 'symbol': 'x', 'color': '#3CA8FF'},
    name = 'Graph'
)
discrete_trace = go.Scatter(
    x = graph_lengths,
    y = np.array(discrete_lengths),
    mode = 'markers',
    marker = {'size': 8, 'symbol': 'x', 'color': '#d32f2f'},
    name = 'Discrete'
)
graphref_trace = go.Scatter(
    x = graph_lengths,
    y = np.array(graphref_lengths),
    mode = 'markers',
    marker = {'size': 8, 'symbol': 'x', 'color': '#2D4366'},
    name = 'Graph Refinement'
)
data = [eucl_trace, graph_trace, discrete_trace, graphref_trace]
_layout = go.Layout(
    width=800,
    height=600,
    margin=go.Margin(l=60, r=60, b=40, t=20),
    xaxis={
        'title': 'Length of graph solution',
        'titlefont': {'size': 18}
    },
    yaxis={
        'title': 'Stochastic Riemannian length',
        'titlefont': {'size': 18}
    },
    legend={
        'font': {'size': 18}
    }
)
iplot(go.Figure(data=data, layout=_layout), config=config)

Conclusion

Stochastic Riemannian length of the shortest curves found:

  • Method 1 (Euclidean Interpolation): 23.7
  • Method 2 (Discrete Geodesic): 21.8
  • Method 3 (ODE): 21.8
  • Method 4 (Graph): 23.0
  • Method 5 (Graph as initialization for discrete): 21.8

Time per 100 geodesic approximations:

  • Method 1 (Euclidean Interpolation): 3s
  • Method 2 (Discrete Geodesic): 4min 47s
  • Method 3 (ODE): 100 × 16min 33s (extrapolated from a single geodesic)
  • Method 4 (Graph): 4s
  • Method 5 (Graph as initialization for discrete): 4min 37s